# Schniepp Lab, 2018-2021

# Amino class to handle amino-acid-related calculations and the related oscillators

import numpy as np
from Vector import Vector
from Atom import Atom

class Amino:
    """One amino-acid residue: its backbone atoms and bond-vector oscillators.

    ``coodinates`` must provide at least seven entries, each indexable as
    ``[0]``, ``[1]``, ``[2]`` (x, y, z), in this order:

    0. N  -- the nitrogen atom
    1. CA -- the C alpha atom
    2. C  -- the carbonyl carbon atom
    3. O  -- the carbonyl oxygen atom
    4. R  -- the functional group (R is an H atom in glycine)
    5. H  -- the hydrogen atom bonded to N
    6. HA -- the hydrogen atom bonded to C alpha

    Oscillators that need the preceding residue (CN stretching and NH
    in-plane bending) are only built when ``prev`` is given; otherwise
    ``self.CN`` and ``self.NHib`` are ``None``.
    """

    def __init__(self, name, idx, coodinates, prev = None):
        """Build the atoms and oscillator vectors of this residue.

        Args:
            name: residue name (stored unchanged as ``self.name``).
            idx: residue index (stored unchanged as ``self.idx``).
            coodinates: per-atom coordinates, ordered as described in the
                class docstring. The parameter keeps its original spelling so
                that keyword callers keep working.
            prev: the preceding ``Amino`` in the chain, or ``None``.
        """
        self.idx = idx
        self.name = name

        def make_atom(element, i):
            # Index x/y/z explicitly so rows with extra columns still work.
            row = coodinates[i]
            return Atom(element, x = row[0], y = row[1], z = row[2])

        self.N = make_atom('N', 0)
        self.CA = make_atom('C', 1)
        self.C = make_atom('C', 2)
        self.O = make_atom('O', 3)
        self.R = make_atom('R', 4)
        self.H = make_atom('H', 5)
        self.HA = make_atom('H', 6)

        # Bond vectors used as oscillators: C->O, CA->C, N->CA, N->H, CA->HA and CA->R.
        self.CO = Vector(self.C.get_Cartesian(), self.O.get_Cartesian())
        self.CAC = Vector(self.CA.get_Cartesian(), self.C.get_Cartesian())
        self.NCA = Vector(self.N.get_Cartesian(), self.CA.get_Cartesian())
        self.NH = Vector(self.N.get_Cartesian(), self.H.get_Cartesian())
        self.CAHA = Vector(self.CA.get_Cartesian(), self.HA.get_Cartesian())
        self.CAR = Vector(self.CA.get_Cartesian(), self.R.get_Cartesian())

        # Normal to the plane spanned by CA->HA and CA->R; it represents the
        # CH2 wagging and twisting.
        vt = np.cross(self.CAHA.vec, self.CAR.vec)
        self.CAHAR_vec = Vector(x = vt[0], y = vt[1], z = vt[2])

        # The CN stretching and NH in-plane bending oscillators need the
        # previous residue, so they are built by set_CN_NHib only when one exists.
        self.prev = prev
        if self.prev is not None:
            self.set_CN_NHib()
        else:
            self.CN = None
            self.NHib = None

    def set_CN_NHib(self):
        """Build the CN stretching and NH in-plane bending oscillators.

        ``self.CN`` points from the previous residue's carbonyl C to this
        residue's N. ``self.NHib`` lies in the plane spanned by CN and N->CA
        and is perpendicular to N->H. If ``self.prev`` is ``None``, a message
        is printed and nothing is changed.
        """
        if self.prev is None:
            print('No previous amino acid initialized, cannot initialize self.CN and self.NHib!')
            return

        self.CN = Vector(self.prev.C.get_Cartesian(), self.N.get_Cartesian())

        # vt is perpendicular to both self.CN and self.NCA (normal of their plane).
        vt = np.cross(self.CN.vec, self.NCA.vec)
        v_t = Vector(x = vt[0], y = vt[1], z = vt[2])

        # vtt is perpendicular to vt (so it lies in the CN/NCA plane) and to self.NH.
        vtt = np.cross(v_t.vec, self.NH.vec)
        self.NHib = Vector(x = vtt[0], y = vtt[1], z = vtt[2])

    def set_HAb(self):
        """Placeholder for an HA bending oscillator; not implemented, does nothing."""
        pass

    def get_P(self, name, oscillator, axis):
        """Return the P value of ``oscillator`` with respect to ``axis``.

        With ``theta = oscillator.get_angle_between(axis)`` (passed to
        ``np.cos``/``np.sin``, so it is expected in radians), the intensity
        along the axis is ``Z = cos(theta)**2`` and perpendicular to it is
        ``X = sin(theta)**2``, and::

            P = (Z - 0.5 * X) / (Z + 0.5 * X)

        Scaling both intensities by the oscillator's squared length would
        cancel in this ratio, so the length is not used. This also keeps the
        denominator >= 0.5 for any finite angle.

        Args:
            name: label of the oscillator; not used in the calculation.
            oscillator: object providing ``get_angle_between(axis)``.
            axis: reference axis passed to ``get_angle_between``.

        Returns:
            The P value as a float, ranging from 1 (parallel) to -1 (perpendicular).
        """
        angle = oscillator.get_angle_between(axis)

        z = np.cos(angle) ** 2  # intensity along the axis
        x = np.sin(angle) ** 2  # intensity perpendicular to the axis
        return (z - 0.5 * x) / (z + 0.5 * x)

    def get_CO_P(self, h_axis):
        """Return the P value for CO stretching relative to ``h_axis``."""
        return self.get_P('CO', self.CO, h_axis)

    def get_CAC_P(self, h_axis):
        """Return the P value for CA(lpha)C stretching relative to ``h_axis``."""
        return self.get_P('CAC', self.CAC, h_axis)

    def get_NCA_P(self, h_axis):
        """Return the P value for NCA(lpha) stretching relative to ``h_axis``."""
        return self.get_P('NCA', self.NCA, h_axis)

    def get_CAHAR_P(self, h_axis):
        """Return the P value for the CH2 wagging/twisting vector relative to ``h_axis``."""
        return self.get_P('CAHAR', self.CAHAR_vec, h_axis)

    def get_CN_P(self, h_axis):
        """Return the P value for CN stretching, or ``None`` (with a message) if CN is unset."""
        # Compare with None explicitly: a Vector's own truthiness must not matter.
        if self.CN is None:
            print('No self.CN initialized, cannot calculate the P value of CN stretching!')
            return None
        return self.get_P('CN', self.CN, h_axis)

    def get_NHib_P(self, h_axis):
        """Return the P value for NH in-plane bending, or ``None`` (with a message) if NHib is unset."""
        # Compare with None explicitly: a Vector's own truthiness must not matter.
        if self.NHib is None:
            print('No self.NHib initialized, cannot calculate the P value of NHib!')
            return None
        return self.get_P('NHib', self.NHib, h_axis)